{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "sa-mobilenetv3.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "ruZ3_IrcWEWv",
        "sYw1h4w-x-TT",
        "KXV4M7d9xOuD",
        "c1nphOROso86",
        "JxDmFR5VYBLi",
        "k_j-p9yDs5tc"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3.7.9 64-bit ('py37': conda)"
    },
    "language_info": {
      "name": "python",
      "version": "3.7.9"
    },
    "interpreter": {
      "hash": "6ab1cb544669fcae44c370d8b60f7937242068b17af606d40fcfd19ff71cacd4"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "MJiTrK4TAyV4"
      },
      "source": [
        "!pip install grad-cam\n",
        "!pip install ttach\n",
        "!pip install colorama"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J9VqGsTHT6h0"
      },
      "source": [
        "import os\n",
        "import time\n",
        "from functools import partial\n",
        "from typing import Any, Callable, Dict, List, Optional, Sequence\n",
        "import math\n",
        "\n",
        "from colorama import Fore, Back, Style\n",
        "from tqdm import tqdm\n",
        "import matplotlib.pyplot as plt\n",
        "import cv2\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "from torch import nn, Tensor\n",
        "from torch.nn.parameter import Parameter\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms, models\n",
        "from torchvision.models.utils import load_state_dict_from_url\n",
        "from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "60mkHqwZhYQ6"
      },
      "source": [
        "# hyper parameters\n",
        "\n",
        "batch_size = 64\n",
        "num_workers = 2\n",
        "input_size = 112\n",
        "\n",
        "epochs = 5\n",
        "lr = 0.001\n",
        "embedding_size = 512\n",
        "\n",
        "save_model = True\n",
        "validation = True"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vZxoUKyMWLDc"
      },
      "source": [
        "# Dataset\n",
        "\n",
        "There are some ways to load the training data set"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qYJFnGk6V0yy"
      },
      "source": [
        "## 1- Mnist"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pWAKXSa8h7lF"
      },
      "source": [
        "def gray2rgb(image):\n",
        "    return image.repeat(3, 1, 1)\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Lambda(gray2rgb),\n",
        "                                transforms.Resize((input_size,input_size)),\n",
        "                              ])\n",
        "\n",
        "train_dataset = datasets.MNIST('./datasets', train=True, download=True, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ruZ3_IrcWEWv"
      },
      "source": [
        "## 2- CIFAR10"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kb-9Pis7WAaY"
      },
      "source": [
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.RandomHorizontalFlip(),\n",
        "                                transforms.Resize((input_size,input_size))\n",
        "                              ])\n",
        "\n",
        "train_dataset = datasets.CIFAR10('./datasets', train=True, download=True, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sYw1h4w-x-TT"
      },
      "source": [
        "## 3- Kaggle"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2xAg_bGCu3ot"
      },
      "source": [
        "!pip install -q kaggle\n",
        "!mkdir -p ~/.kaggle\n",
        "!cp kaggle.json ~/.kaggle/"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sMOKRvaBxZlh"
      },
      "source": [
        "!kaggle datasets list -s \"Gender Classification 200K Images | CelebA\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zW-a3movxY1U"
      },
      "source": [
        "!kaggle datasets download -d ashishjangra27/gender-recognition-200k-images-celeba\n",
        "!unzip -qq gender-recognition-200k-images-celeba.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YcuaJno-NWGQ"
      },
      "source": [
        "dataset_dir_path = '/content/Dataset/Train'\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                # transforms.CenterCrop(input_size),\n",
        "                                transforms.RandomHorizontalFlip(),\n",
        "                                # transforms.RandomRotation(20),\n",
        "                                transforms.Resize((input_size,input_size)),\n",
        "                                # transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                                #                      std=[0.229, 0.224, 0.225])\n",
        "                                ])\n",
        "\n",
        "train_dataset = torchvision.datasets.ImageFolder(root=dataset_dir_path, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KXV4M7d9xOuD"
      },
      "source": [
        "## 4- Google drive"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m0v4T4cdxUyB"
      },
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p-e0DzXNjv0p"
      },
      "source": [
        "# dataset_dir_path = '/content/drive/MyDrive/datasets/Flowers/Train'\n",
        "dataset_dir_path = '/content/drive/MyDrive/datasets/CASIA-WebFace-mini-aligned'\n",
        "\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                # transforms.CenterCrop(input_size),\n",
        "                                transforms.RandomHorizontalFlip(),\n",
        "                                # transforms.RandomRotation(20),\n",
        "                                transforms.Resize((input_size,input_size)),\n",
        "                                # transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                                #                      std=[0.229, 0.224, 0.225])\n",
        "                                ])\n",
        "\n",
        "train_dataset = torchvision.datasets.ImageFolder(root=dataset_dir_path, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mMfNsoJXNsfv"
      },
      "source": [
        "## Load"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8uFliZOvNCrU"
      },
      "source": [
        "if validation:\n",
        "    train_size = int(0.8 * len(train_dataset))\n",
        "    val_size = len(train_dataset) - train_size\n",
        "    train_set, val_set = torch.utils.data.random_split(train_dataset, [train_size, val_size])\n",
        "    train_dataloader = DataLoader(train_set, num_workers=num_workers, shuffle=True, batch_size=batch_size)\n",
        "    val_dataloader = DataLoader(val_set, num_workers=num_workers, shuffle=True, batch_size=batch_size)\n",
        "    print(train_size, val_size)\n",
        "\n",
        "else:\n",
        "    train_dataloader = DataLoader(train_dataset, num_workers=num_workers, shuffle=True, batch_size=batch_size)\n",
        "    train_size = int(len(train_dataset))\n",
        "    print(train_size)\n",
        "\n",
        "num_classes = len(train_dataset.classes)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jk5z7ihnOqqn"
      },
      "source": [
        "data, target = iter(train_dataloader).next()\n",
        "\n",
        "print(target[0])\n",
        "print(data[0].shape)\n",
        "image = data[0].permute(1, 2, 0)\n",
        "plt.imshow(image)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SsBk5sNhD4x7"
      },
      "source": [
        "#Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9LoHciYAyE7E"
      },
      "source": [
        "class sa_layer(nn.Module):\n",
        "    def __init__(self, channel, groups=4):\n",
        "        super(sa_layer, self).__init__()\n",
        "        self.groups = groups\n",
        "        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
        "        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))\n",
        "        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))\n",
        "        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))\n",
        "        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))\n",
        "\n",
        "        self.sigmoid = nn.Sigmoid()\n",
        "        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))\n",
        "\n",
        "    @staticmethod\n",
        "    def channel_shuffle(x, groups):\n",
        "        b, c, h, w = x.shape\n",
        "\n",
        "        x = x.reshape(b, groups, -1, h, w)\n",
        "        x = x.permute(0, 2, 1, 3, 4)\n",
        "\n",
        "        # flatten\n",
        "        x = x.reshape(b, -1, h, w)\n",
        "        return x\n",
        "\n",
        "    def forward(self, x):\n",
        "        b, c, h, w = x.shape\n",
        "\n",
        "        x = x.reshape(b * self.groups, -1, h, w)\n",
        "        x_0, x_1 = x.chunk(2, dim=1)\n",
        "\n",
        "        # channel attention\n",
        "        xn = self.avg_pool(x_0)\n",
        "        xn = self.cweight * xn + self.cbias\n",
        "        xn = x_0 * self.sigmoid(xn)\n",
        "\n",
        "        # spatial attention\n",
        "        xs = self.gn(x_1)\n",
        "        xs = self.sweight * xs + self.sbias\n",
        "        xs = x_1 * self.sigmoid(xs)\n",
        "\n",
        "        # concatenate along channel axis\n",
        "        out = torch.cat([xn, xs], dim=1)\n",
        "        out = out.reshape(b, -1, h, w)\n",
        "\n",
        "        out = self.channel_shuffle(out, 2)\n",
        "        return out\n",
        "\n",
        "\n",
        "class InvertedResidualConfig:\n",
        "    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,\n",
        "                 activation: str, stride: int, dilation: int, width_mult: float):\n",
        "        self.input_channels = self.adjust_channels(input_channels, width_mult)\n",
        "        self.kernel = kernel\n",
        "        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)\n",
        "        self.out_channels = self.adjust_channels(out_channels, width_mult)\n",
        "        self.use_se = use_se\n",
        "        self.use_hs = activation == \"HS\"\n",
        "        self.stride = stride\n",
        "        self.dilation = dilation\n",
        "\n",
        "    @staticmethod\n",
        "    def adjust_channels(channels: int, width_mult: float):\n",
        "        return _make_divisible(channels * width_mult, 8)\n",
        "\n",
        "\n",
        "class InvertedResidual(nn.Module):\n",
        "    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):\n",
        "        super().__init__()\n",
        "        if not (1 <= cnf.stride <= 2):\n",
        "            raise ValueError('illegal stride value')\n",
        "\n",
        "        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels\n",
        "\n",
        "        layers: List[nn.Module] = []\n",
        "        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU\n",
        "\n",
        "        # expand\n",
        "        if cnf.expanded_channels != cnf.input_channels:\n",
        "            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,\n",
        "                                           norm_layer=norm_layer, activation_layer=activation_layer))\n",
        "\n",
        "        # depthwise\n",
        "        stride = 1 if cnf.dilation > 1 else cnf.stride\n",
        "        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,\n",
        "                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,\n",
        "                                       norm_layer=norm_layer, activation_layer=activation_layer))\n",
        "        if cnf.use_se:\n",
        "            layers.append(sa_layer(cnf.expanded_channels))\n",
        "\n",
        "        # project\n",
        "        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,\n",
        "                                       activation_layer=nn.Identity))\n",
        "\n",
        "        self.block = nn.Sequential(*layers)\n",
        "        self.out_channels = cnf.out_channels\n",
        "        self._is_cn = cnf.stride > 1\n",
        "\n",
        "    def forward(self, input: Tensor) -> Tensor:\n",
        "        result = self.block(input)\n",
        "        if self.use_res_connect:\n",
        "            result += input\n",
        "        return result\n",
        "\n",
        "\n",
        "class MobileNetV3(nn.Module):\n",
        "    def __init__(self,\n",
        "                 inverted_residual_setting: List[InvertedResidualConfig],\n",
        "                 last_channel: int,\n",
        "                 num_classes: int = 1000,\n",
        "                 embedding_size: int = 512,\n",
        "                 block: Optional[Callable[..., nn.Module]] = None,\n",
        "                 norm_layer: Optional[Callable[..., nn.Module]] = None):\n",
        "        \"\"\"\n",
        "        MobileNet V3 main class\n",
        "        Args:\n",
        "            inverted_residual_setting (List[InvertedResidualConfig]): Network structure\n",
        "            last_channel (int): The number of channels on the penultimate layer\n",
        "            num_classes (int): Number of classes\n",
        "            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet\n",
        "            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "\n",
        "        if not inverted_residual_setting:\n",
        "            raise ValueError(\"The inverted_residual_setting should not be empty\")\n",
        "        elif not (isinstance(inverted_residual_setting, Sequence) and\n",
        "                  all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):\n",
        "            raise TypeError(\"The inverted_residual_setting should be List[InvertedResidualConfig]\")\n",
        "\n",
        "        if block is None:\n",
        "            block = InvertedResidual\n",
        "\n",
        "        if norm_layer is None:\n",
        "            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)\n",
        "\n",
        "        layers: List[nn.Module] = []\n",
        "\n",
        "        # building first layer\n",
        "        firstconv_output_channels = inverted_residual_setting[0].input_channels\n",
        "        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,\n",
        "                                       activation_layer=nn.Hardswish))\n",
        "\n",
        "        # building inverted residual blocks\n",
        "        for cnf in inverted_residual_setting:\n",
        "            layers.append(block(cnf, norm_layer))\n",
        "\n",
        "        # building last several layers\n",
        "        lastconv_input_channels = inverted_residual_setting[-1].out_channels\n",
        "        lastconv_output_channels = 6 * lastconv_input_channels\n",
        "        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,\n",
        "                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))\n",
        "\n",
        "        self.features = nn.Sequential(*layers)\n",
        "        self.avgpool = nn.AdaptiveAvgPool2d(1)\n",
        "\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(lastconv_output_channels, last_channel),\n",
        "            nn.Hardswish(inplace=True),\n",
        "            nn.Dropout(p=0.2, inplace=True),\n",
        "            nn.Linear(last_channel, num_classes),\n",
        "        )\n",
        "\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Conv2d):\n",
        "                nn.init.kaiming_normal_(m.weight, mode='fan_out')\n",
        "                if m.bias is not None:\n",
        "                    nn.init.zeros_(m.bias)\n",
        "            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm)):\n",
        "                nn.init.ones_(m.weight)\n",
        "                nn.init.zeros_(m.bias)\n",
        "            elif isinstance(m, nn.Linear):\n",
        "                nn.init.normal_(m.weight, 0, 0.01)\n",
        "                nn.init.zeros_(m.bias)\n",
        "\n",
        "\n",
        "    def forward(self, x: Tensor) -> Tensor:\n",
        "        x = self.features(x)\n",
        "        x = self.avgpool(x)\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):\n",
        "    # non-public config parameters\n",
        "    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1\n",
        "    dilation = 2 if params.pop('_dilated', False) else 1\n",
        "    width_mult = params.pop('_width_mult', 1.0)\n",
        "\n",
        "    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)\n",
        "    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)\n",
        "\n",
        "    if arch == \"mobilenet_v3_large\":\n",
        "        inverted_residual_setting = [\n",
        "            bneck_conf(16, 3, 16, 16, False, \"RE\", 1, 1),\n",
        "            bneck_conf(16, 3, 64, 24, False, \"RE\", 2, 1),  # C1\n",
        "            bneck_conf(24, 3, 72, 24, False, \"RE\", 1, 1),\n",
        "            bneck_conf(24, 5, 72, 40, True, \"RE\", 2, 1),  # C2\n",
        "            bneck_conf(40, 5, 120, 40, True, \"RE\", 1, 1),\n",
        "            bneck_conf(40, 5, 120, 40, True, \"RE\", 1, 1),\n",
        "            bneck_conf(40, 3, 240, 80, False, \"HS\", 2, 1),  # C3\n",
        "            bneck_conf(80, 3, 200, 80, False, \"HS\", 1, 1),\n",
        "            bneck_conf(80, 3, 184, 80, False, \"HS\", 1, 1),\n",
        "            bneck_conf(80, 3, 184, 80, False, \"HS\", 1, 1),\n",
        "            bneck_conf(80, 3, 480, 112, True, \"HS\", 1, 1),\n",
        "            bneck_conf(112, 3, 672, 112, True, \"HS\", 1, 1),\n",
        "            bneck_conf(112, 5, 672, 160 // reduce_divider, True, \"HS\", 2, dilation),  # C4\n",
        "            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, \"HS\", 1, dilation),\n",
        "            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, \"HS\", 1, dilation),\n",
        "        ]\n",
        "        last_channel = adjust_channels(1280 // reduce_divider)  # C5\n",
        "    elif arch == \"mobilenet_v3_small\":\n",
        "        inverted_residual_setting = [\n",
        "            bneck_conf(16, 3, 16, 16, True, \"RE\", 2, 1),  # C1\n",
        "            bneck_conf(16, 3, 72, 24, False, \"RE\", 2, 1),  # C2\n",
        "            bneck_conf(24, 3, 88, 24, False, \"RE\", 1, 1),\n",
        "            bneck_conf(24, 5, 96, 40, True, \"HS\", 2, 1),  # C3\n",
        "            bneck_conf(40, 5, 240, 40, True, \"HS\", 1, 1),\n",
        "            bneck_conf(40, 5, 240, 40, True, \"HS\", 1, 1),\n",
        "            bneck_conf(40, 5, 120, 48, True, \"HS\", 1, 1),\n",
        "            bneck_conf(48, 5, 144, 48, True, \"HS\", 1, 1),\n",
        "            bneck_conf(48, 5, 288, 96 // reduce_divider, True, \"HS\", 2, dilation),  # C4\n",
        "            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, \"HS\", 1, dilation),\n",
        "            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, \"HS\", 1, dilation),\n",
        "        ]\n",
        "        last_channel = adjust_channels(1024 // reduce_divider)  # C5\n",
        "    else:\n",
        "        raise ValueError(\"Unsupported model type {}\".format(arch))\n",
        "\n",
        "    return inverted_residual_setting, last_channel\n",
        "\n",
        "\n",
        "def _mobilenet_v3_model(\n",
        "    arch: str,\n",
        "    inverted_residual_setting: List[InvertedResidualConfig],\n",
        "    last_channel: int,\n",
        "    pretrained: bool,\n",
        "    progress: bool,\n",
        "    **kwargs: Any):\n",
        "    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)\n",
        "    return model\n",
        "\n",
        "\n",
        "def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:\n",
        "    arch = \"mobilenet_v3_large\"\n",
        "    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)\n",
        "    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)\n",
        "\n",
        "\n",
        "def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:\n",
        "    arch = \"mobilenet_v3_small\"\n",
        "    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)\n",
        "    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W7NDwItSUDst"
      },
      "source": [
        "device = torch.device('cuda')\n",
        "model = mobilenet_v3_large(embedding_size=embedding_size, num_classes=num_classes).to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2P2UJ2ZqV7yz"
      },
      "source": [
        "#Train"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aIRetToEQCFN"
      },
      "source": [
        "loss_fn = torch.nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QrdnEytR-CfE"
      },
      "source": [
        "def calc_acc(preds: torch.Tensor, labels: torch.Tensor):\n",
        "    _, pred_max = torch.max(preds, 1)\n",
        "    acc = torch.sum(pred_max == labels.data, dtype=torch.float64) / len(preds)\n",
        "    return acc"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WRHxuXzWp6fW"
      },
      "source": [
        "tic = time.time()\n",
        "\n",
        "for epoch in range(1, epochs + 1):\n",
        "    model.train(True)\n",
        "    train_loss = 0.0\n",
        "    train_acc = 0.0\n",
        "    for images, labels in tqdm(train_dataloader, desc=\"Training\"):\n",
        "        images, labels = images.to(device), labels.to(device)\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        preds = model(images)\n",
        "        loss = loss_fn(preds, labels)\n",
        "        \n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        train_loss += loss\n",
        "        train_acc += calc_acc(preds, labels)\n",
        "\n",
        "    total_loss = train_loss / len(train_dataloader)\n",
        "    total_acc = train_acc / len(train_dataloader)\n",
        "    print(Fore.GREEN, f\"Epoch: {epoch} [Train Loss: {total_loss}] [Train Accuracy: {total_acc}] [lr: {optimizer.param_groups[0]['lr']}]\", Fore.RESET)\n",
        "\n",
        "    if validation:\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            val_loss = 0.0\n",
        "            val_acc = 0.0\n",
        "            for images, labels in tqdm(val_dataloader, desc=\"Validating\"):\n",
        "                images, labels = images.to(device), labels.to(device)\n",
        "                \n",
        "                preds = model(images)\n",
        "                loss = loss_fn(preds, labels)\n",
        "                val_loss += loss\n",
        "                val_acc += calc_acc(preds, labels)\n",
        "\n",
        "            total_loss = val_loss / len(val_dataloader)\n",
        "            # total_loss = 0\n",
        "            total_acc = val_acc / len(val_dataloader)\n",
        "            print(Fore.BLUE, f\"[Validation Loss: {total_loss}] [Validation Accuracy: {total_acc}]\", Fore.RESET)\n",
        "\n",
        "    scheduler.step()\n",
        "\n",
        "tac = time.time()\n",
        "print(\"Time Taken : \", tac - tic)\n",
        "\n",
        "torch.save(model.state_dict(), \"weights.pth\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DfDeS37zqPbA"
      },
      "source": [
        "#Test"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C2fIKbuFSUgT"
      },
      "source": [
        "torch.cuda.empty_cache()\n",
        "model.eval()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fG8JV7n_XUoO"
      },
      "source": [
        "model.load_state_dict(torch.load(\"weights.pth\", map_location=device), strict=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Af7F2NsyVupo"
      },
      "source": [
        "# Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c1nphOROso86"
      },
      "source": [
        "## 1- MNIST"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dP0WVr2LOUjG"
      },
      "source": [
        "def gray2rgb(image):\n",
        "    return image.repeat(3, 1, 1)\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Lambda(gray2rgb),\n",
        "                                transforms.Resize((input_size,input_size)),\n",
        "                              ])\n",
        "\n",
        "test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JxDmFR5VYBLi"
      },
      "source": [
        "## 2- CFAR10"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yKAIjrtJYArt"
      },
      "source": [
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Resize((input_size,input_size))\n",
        "                              ])\n",
        "\n",
        "test_dataset = datasets.CIFAR10('.', train=False, download=True, transform=transform)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k_j-p9yDs5tc"
      },
      "source": [
        "## 3- Google drive"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uABYPDOvsr-C"
      },
      "source": [
        "dataset_dir_path = '/content/drive/MyDrive/datasets/Flowers/Test'\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Resize((input_size,input_size))\n",
        "                                ])\n",
        "\n",
        "test_dataset = torchvision.datasets.ImageFolder(root=dataset_dir_path, transform=transform)\n",
        "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
        "\n",
        "num_classes = len(test_dataset.classes)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7yRfiPNBksjp"
      },
      "source": [
        "## Load"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MKKUmQcwWF26"
      },
      "source": [
        "test_dataloader = torch.utils.data.DataLoader(test_dataset, num_workers=num_workers, shuffle=False, batch_size=batch_size)\n",
        "num_classes = len(test_dataset.classes)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6WU8FUElf3Zb"
      },
      "source": [
        "data, target = iter(test_dataloader).next()\n",
        "\n",
        "print(target[0])\n",
        "print(data[0].shape)\n",
        "image = data[0].permute(1, 2, 0)\n",
        "plt.imshow(image)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2yleIUk0CgxN"
      },
      "source": [
        "test_acc = 0.0\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, labels in tqdm(test_dataloader, desc=\"Testing\"):\n",
        "        images, labels = images.to(device), labels.to(device)\n",
        "        \n",
        "        preds = model(images)\n",
        "        test_acc += calc_acc(preds, labels)\n",
        "\n",
        "acc = test_acc / len(test_dataloader)\n",
        "print(Fore.BLUE, f\"[Test Accuracy: {acc}]\", Fore.RESET)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sZiHIx9QCp19"
      },
      "source": [
        "#Inference"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KZljPyz5RkMI"
      },
      "source": [
        "test_transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                     transforms.Resize((input_size, input_size))\n",
        "                                    ])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TRkkopQon3vT"
      },
      "source": [
        "image = cv2.imread('/content/image.jpg')\n",
        "image = cv2.resize(image, (input_size, input_size))\n",
        "\n",
        "with torch.no_grad():\n",
        "    image_tensor = test_transform(image).unsqueeze(0).to(device)\n",
        "    preds = model(image_tensor)\n",
        "    output = torch.argmax(preds, dim=1)\n",
        "    print(output)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lNuFnXZHncNf"
      },
      "source": [
        "from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM\n",
        "from pytorch_grad_cam.utils.image import show_cam_on_image\n",
        "\n",
        "target_layer = model.features[-1]\n",
        "cam = GradCAM(model=model, target_layer=target_layer, use_cuda=True)\n",
        "\n",
        "input_tensor = image_tensor.float()\n",
        "grayscale_cam = cam(input_tensor=image_tensor, target_category=None, aug_smooth=True)\n",
        "\n",
        "grayscale_cam = grayscale_cam[0]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "caGu-9CahpF3"
      },
      "source": [
        "image_normal = image / 255.0\n",
        "visualization = show_cam_on_image(image_normal, grayscale_cam)\n",
        "cv2.imwrite('/content/output.jpg', visualization)\n",
        "\n",
        "plt.figure()\n",
        "plt.imshow(image)\n",
        "plt.figure()\n",
        "plt.imshow(visualization)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}